# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for source_utils."""

import ast
import os
import sys
import tempfile
import zipfile

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_loop
from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect


def line_number_above():
  """Return the lineno of the statement directly above the caller's line.

  The call must be placed right after the statement of interest: no empty
  line(s) may separate the call site from the preceding AST node.

  Returns:
    The lineno of the AST node that precedes the call site, at the same level
    of the AST. If that node spans multiple lines:
      - In Python 3.8+, the lineno of its first line is returned.
      - In older Python versions, the lineno of its last line is returned.
  """
  # Line number recorded for the frame that called this function.
  caller_lineno = tf_inspect.stack()[1][2]

  if sys.version_info < (3, 8):
    return caller_lineno - 1

  # https://bugs.python.org/issue12458: Starting with Python 3.8, traceback
  # reports the first line of a multi-line continuation block rather than the
  # last one. To stay consistent with that, look the preceding statement up in
  # the parsed source of this file and use the lineno of its first line.
  with open(__file__, "rb") as source_file:
    raw_source = source_file.read()
  module_ast = ast.parse(raw_source.decode("utf-8"))
  return _find_preceding_ast_node(module_ast, caller_lineno).lineno


def _find_preceding_ast_node(node, lineno):
  """Find the AST node immediately before the statement starting at `lineno`.

  The search is depth-first over the statement lists of `node` (`body`,
  `orelse`, `finalbody` and `handlers`, where present), so a statement nested
  inside e.g. an `else:` or `finally:` block is found as well.

  Args:
    node: The `ast` node to search, e.g. the `ast.Module` returned by
      `ast.parse`.
    lineno: 1-based line number on which the statement of interest starts.

  Returns:
    The sibling node that directly precedes the statement starting at `lineno`
    within the same statement list, or `None` if no statement starts at
    `lineno` or if that statement is the first one of its list.
  """
  for field in ("body", "orelse", "finalbody", "handlers"):
    statements = getattr(node, field, None)
    # Some nodes reuse these field names for a single expression (e.g. the
    # `body` of a lambda); only real statement lists are searched.
    if not isinstance(statements, list):
      continue
    for i, child_node in enumerate(statements):
      if getattr(child_node, "lineno", None) == lineno:
        # Guard against i == 0: `statements[-1]` would silently wrap around to
        # the last element of the list, which is not a preceding node.
        return statements[i - 1] if i > 0 else None
      found_node = _find_preceding_ast_node(child_node, lineno)
      if found_node is not None:
        return found_node
  return None


class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
  """Tests for `source_utils.guess_is_tensorflow_py_library`."""

  def setUp(self):
    # Absolute, normalized path of this test file.
    this_file = os.path.abspath(__file__)
    self.curr_file_path = os.path.normpath(this_file)

  def tearDown(self):
    ops.reset_default_graph()

  def testGuessedBaseDirIsProbablyCorrect(self):
    # Outside of pip, the code resides in "tensorflow/". In the pip world,
    # after virtual pip, it resides in "tensorflow_core/". Accept either one.
    base_dir_name = os.path.basename(source_utils._TENSORFLOW_BASEDIR)
    self.assertIn(base_dir_name, ["tensorflow", "tensorflow_core"])

  def testUnitTestFileReturnsFalse(self):
    guessed = source_utils.guess_is_tensorflow_py_library(self.curr_file_path)
    self.assertFalse(guessed)

  def testSourceUtilModuleReturnsTrue(self):
    guessed = source_utils.guess_is_tensorflow_py_library(
        source_utils.__file__)
    self.assertTrue(guessed)

  @test_util.run_v1_only("Tensor.op is not available in TF 2.x")
  def testFileInPythonKernelsPathReturnsTrue(self):
    const_tensor = constant_op.constant(42.0, name="x")
    # Item 0 of the last traceback entry is passed as the file path to check.
    last_frame = const_tensor.op.traceback[-1]
    self.assertTrue(
        source_utils.guess_is_tensorflow_py_library(last_frame[0]))

  def testDebuggerExampleFilePathReturnsFalse(self):
    # Debugger example scripts live inside the TensorFlow tree but must not be
    # classified as TensorFlow Python library files.
    example_paths = (
        "site-packages/tensorflow/python/debug/examples/debug_mnist.py",
        "site-packages/tensorflow/python/debug/examples/v1/example_v1.py",
        "site-packages/tensorflow/python/debug/examples/v2/example_v2.py",
        "site-packages/tensorflow/python/debug/examples/v3/example_v3.py",
    )
    for example_path in example_paths:
      normalized_path = os.path.normpath(example_path)
      self.assertFalse(
          source_utils.guess_is_tensorflow_py_library(normalized_path))

  def testReturnsFalseForNonPythonFile(self):
    test_dir = os.path.dirname(self.curr_file_path)
    non_python_path = os.path.join(test_dir, "foo.cc")
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(non_python_path))

  def testReturnsFalseForStdin(self):
    guessed = source_utils.guess_is_tensorflow_py_library("<stdin>")
    self.assertFalse(guessed)

  def testReturnsFalseForEmptyFileName(self):
    guessed = source_utils.guess_is_tensorflow_py_library("")
    self.assertFalse(guessed)


class SourceHelperTest(test_util.TensorFlowTestCase):
  """Tests for source annotation and source loading in `source_utils`."""

  def createAndRunGraphHelper(self):
    """Create and run a TensorFlow Graph to generate debug dumps.

    This is intentionally done in a separate method, to make it easier to test
    the stack-top mode of source annotation.
    """

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    # NOTE: Every `line_number_above()` call below must directly follow the
    # statement whose line number it records; do not insert blank lines or
    # comments between the two.
    with session.Session() as sess:
      self.u_init = constant_op.constant(
          np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init")
      self.u_init_line_number = line_number_above()

      self.u = variables.Variable(self.u_init, name="u")
      self.u_line_number = line_number_above()

      self.v_init = constant_op.constant(
          np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init")
      self.v_init_line_number = line_number_above()

      self.v = variables.Variable(self.v_init, name="v")
      self.v_line_number = line_number_above()

      self.w = math_ops.matmul(self.u, self.v, name="w")
      self.w_line_number = line_number_above()

      self.evaluate(self.u.initializer)
      self.evaluate(self.v.initializer)

      # Watch the graph so that running it writes debug dumps to dump_root.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(self.w, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def _createParFileWithModel(self):
    """Create a .par file (written as a zip archive) holding a model source.

    The archive contains a single member, `tensorflow_models/model.py`.

    Returns:
      Path to the created .par file.
    """
    temp_file_path = os.path.join(self.get_temp_dir(), "model.py")
    with open(temp_file_path, "wb") as f:
      f.write(b"import tensorflow as tf\nx = tf.constant(42.0)\n")
    par_path = os.path.join(self.get_temp_dir(), "train_model.par")
    with zipfile.ZipFile(par_path, "w") as zf:
      zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))
    return par_path

  def setUp(self):
    # The line of the helper call is recorded as well: the call appears in the
    # tests below as the "helper line" of the annotated source.
    self.createAndRunGraphHelper()
    self.helper_line_number = line_number_above()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

  def testAnnotateWholeValidSourceFileGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(self.dump,
                                                     self.curr_file_path)

    # Each op must be listed under the line on which it was created.
    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the non-stack-top (default) mode, the helper line should be annotated
    # with all the ops as well.
    self.assertIn(self.u_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.helper_line_number])

  def testAnnotateWithStackTopGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, file_stack_top=True)

    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the stack-top mode, the helper line should not have been annotated.
    self.assertNotIn(self.helper_line_number, source_annotation)

  def testAnnotateSubsetOfLinesGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump,
        self.curr_file_path,
        min_line=self.u_line_number,
        max_line=self.u_line_number + 1)

    # Only the line that creates `u` lies in the requested range; the line
    # that creates `v` must be absent from the annotation.
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v_line_number, source_annotation)

  def testAnnotateDumpedTensorsGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, do_dumped_tensors=True)

    # Note: Constant Tensors u_init and v_init may not get dumped due to
    #   constant-folding.
    self.assertIn(self.u.name, source_annotation[self.u_line_number])
    self.assertIn(self.v.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.name, source_annotation[self.w_line_number])

    # With do_dumped_tensors=True the annotation holds tensor names, not op
    # names.
    self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number])

    self.assertIn(self.u.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.name, source_annotation[self.helper_line_number])

  def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self):
    self.dump.set_python_graph(None)
    with self.assertRaises(ValueError):
      source_utils.annotate_source(self.dump, self.curr_file_path)

  def testCallingAnnotateSourceOnUnrelatedSourceFileDoesNotError(self):
    # Create an unrelated source file.
    fd, unrelated_source_path = tempfile.mkstemp()
    try:
      with open(fd, "wt") as source_file:
        source_file.write("print('hello, world')\n")

      self.assertEqual({},
                       source_utils.annotate_source(self.dump,
                                                    unrelated_source_path))
    finally:
      # Clean up the unrelated source file, even if the assertion fails.
      os.remove(unrelated_source_path)

  def testLoadingPythonSourceFileWithNonAsciiChars(self):
    fd, source_path = tempfile.mkstemp()
    try:
      with open(fd, "wb") as source_file:
        source_file.write(u"print('\U0001f642')\n".encode("utf-8"))
      source_lines, _ = source_utils.load_source(source_path)
      self.assertEqual(source_lines, [u"print('\U0001f642')", u""])
    finally:
      # Clean up the temporary source file, even if the assertion fails.
      os.remove(source_path)

  def testLoadNonexistentNonParPathFailsWithIOError(self):
    bad_path = os.path.join(self.get_temp_dir(), "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(bad_path)

  def testLoadingPythonSourceFileInParFileSucceeds(self):
    # Create the .par file first.
    par_path = self._createParFileWithModel()

    # The source path points at a member inside the .par archive.
    source_path = os.path.join(par_path, "tensorflow_models", "model.py")
    source_lines, _ = source_utils.load_source(source_path)
    self.assertEqual(
        source_lines, ["import tensorflow as tf", "x = tf.constant(42.0)", ""])

  def testLoadingPythonSourceFileInParFileFailsRaisingIOError(self):
    # Create the .par file first.
    par_path = self._createParFileWithModel()

    # The requested member does not exist inside the .par archive.
    source_path = os.path.join(par_path, "tensorflow_models", "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(source_path)


@test_util.run_v1_only("Sessions are not available in TF 2.x")
class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
  """Tests for `source_utils.list_source_files_against_dump`."""

  def createAndRunGraphWithWhileLoop(self):
    """Create and run a TensorFlow Graph with a while loop to generate dumps."""

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    # NOTE: The `line_number_above()` call must directly follow the `loop_body`
    # assignment; do not insert blank lines or comments between the two.
    with session.Session() as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      self.traceback_first_line = line_number_above()

      loop_cond = lambda i: math_ops.less(i, 16)

      i = constant_op.constant(10, name="i")
      loop = while_loop.while_loop(loop_cond, loop_body, [i])

      # Watch the graph so that running it writes debug dumps to dump_root.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(loop, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def _assertSourceListIsWellFormed(self, source_list):
    """Check the structural invariants of a source list.

    Asserts that the file paths (item 0 of every entry) are sorted and unique,
    and that every entry is a tuple of length 6.

    Args:
      source_list: The value returned by
        `source_utils.list_source_files_against_dump`.

    Returns:
      The list of file paths, in the order in which they appear in
      `source_list`.
    """
    # Assert that the file paths are sorted and unique.
    file_paths = [item[0] for item in source_list]
    self.assertEqual(sorted(file_paths), file_paths)
    self.assertEqual(len(set(file_paths)), len(file_paths))

    # Assert that each item of source_list is a tuple of length 6.
    for item in source_list:
      self.assertIsInstance(item, tuple)
      self.assertEqual(6, len(item))

    return file_paths

  def setUp(self):
    self.createAndRunGraphWithWhileLoop()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

  def testGenerateSourceList(self):
    source_list = source_utils.list_source_files_against_dump(self.dump)
    file_paths = self._assertSourceListIsWellFormed(source_list)

    # The while loop body should have executed 3 times. The following table
    # lists the tensors and how many times each of them is dumped.
    #   Tensor name            # of times dumped:
    #   i:0                    1
    #   while/Enter:0          1
    #   while/Merge:0          4
    #   while/Merge:1          4
    #   while/Less/y:0         4
    #   while/Less:0           4
    #   while/LoopCond:0       4
    #   while/Switch:0         1
    #   while/Switch:1         3
    #   while/Identity:0       3
    #   while/Add/y:0          3
    #   while/Add:0            3
    #   while/NextIteration:0  3
    #   while/Exit:0           1
    # ----------------------------
    #   (Total)                39
    #
    # The total number of nodes is 12.
    # The total number of tensors is 14 (2 of the nodes have 2 outputs:
    #   while/Merge, while/Switch).

    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)

  def testGenerateSourceListWithNodeNameFilter(self):
    source_list = source_utils.list_source_files_against_dump(
        self.dump, node_name_regex_allowlist=r"while/Add.*")
    file_paths = self._assertSourceListIsWellFormed(source_list)

    # Due to the node-name filtering the result should only contain 2 nodes
    # and 2 tensors. The total number of dumped tensors should be 6:
    #   while/Add/y:0          3
    #   while/Add:0            3
    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(2, num_nodes)
    self.assertEqual(2, num_tensors)
    self.assertEqual(6, num_dumps)

  def testGenerateSourceListWithPathRegexFilter(self):
    curr_file_basename = os.path.basename(self.curr_file_path)
    # Allow only paths ending with this file's base name; the dots in the name
    # are escaped so that they match literally.
    source_list = source_utils.list_source_files_against_dump(
        self.dump,
        path_regex_allowlist=(".*" + curr_file_basename.replace(".", "\\.") +
                              "$"))

    # Only this file should pass the path filter; its statistics are the same
    # as in the unfiltered case.
    self.assertEqual(1, len(source_list))
    (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
     first_line) = source_list[0]
    self.assertEqual(self.curr_file_path, file_path)
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)


# Hand control to the googletest entry point when this file is executed as a
# script (as opposed to being imported as a module).
if __name__ == "__main__":
  googletest.main()
